# Setup ----
# library() (unlike require()) stops with an error when a package is
# missing, instead of returning FALSE and letting the script fail later.
library(quanteda)
library(seededlda)
library(ggplot2)
library(furrr)
library(purrr)
library(knitr)
library(reshape2)
library(dplyr)

# Thread count for quanteda operations in this (main) R session.
# NOTE(review): multisession workers are separate R processes and do not
# inherit this option — confirm whether workers should set it too.
quanteda_options(threads = 60)

# Project-local definitions (contents not visible here).
source("functions.R")

# Background R sessions used by future_map().
plan(multisession, workers = 32)

# Tables and figures are written under output/. Create the directory up front
# so that a long model fit does not fail only at the final write.
dir.create("output", showWarnings = FALSE, recursive = TRUE)

# Seed dictionary ----
# dictionary.yml nests seed patterns two levels deep under `seeds`
# (seeds -> level-2 group -> level-3 category -> patterns). Only the level-3
# categories become seeded topics, so the level-2 grouping is flattened away.
dict_yaml <- yaml::read_yaml("dictionary/dictionary.yml")

seeds_level3 <- list()
for (lvl2 in names(dict_yaml$seeds)) {
  group <- dict_yaml$seeds[[lvl2]]
  for (lvl3 in names(group)) {
    # A level-3 name reused under two level-2 groups would otherwise silently
    # overwrite the earlier pattern list and drop its words from the model.
    if (lvl3 %in% names(seeds_level3)) {
      stop("Level-3 seed category '", lvl3, "' appears under more than one ",
           "level-2 group in dictionary/dictionary.yml", call. = FALSE)
    }
    seeds_level3[[lvl3]] <- group[[lvl3]]
  }
}

dict <- dictionary(seeds_level3)

# Document-feature matrix ----
# Each data/tokens_*.RDS file is expected to hold a quanteda tokens object.
# Files are cleaned and converted to DFMs in parallel, then stacked row-wise.
token_dir <- "data/"
token_files <- list.files(token_dir,
                          pattern = "tokens_.*\\.RDS$",
                          full.names = TRUE)
# Without this guard, do.call(rbind, list()) returns NULL and dfm_trim()
# fails with an unrelated-looking error.
if (length(token_files) == 0L) {
  stop("No files matching 'tokens_*.RDS' found in ", token_dir, call. = FALSE)
}

dfm_list <- future_map(token_files, function(f) {
  toks <- readRDS(f)
  # Remove any token containing a digit. min_nchar = 2 additionally removes
  # tokens shorter than two characters; quanteda applies it on top of the
  # pattern match.
  toks <- tokens_remove(toks, "\\d", valuetype = "regex", min_nchar = 2)
  dfm(toks)
})

dfmt <- do.call(rbind, dfm_list)
# Keep features with a total count of at least 10 that appear in at most
# half of the documents.
dfmt <- dfmt |>
  dfm_trim(min_termfreq = 10,
           max_docfreq = 0.5,
           docfreq_type = "prop")

# Seeded LDA ----
# One seeded topic per entry of `dict` plus two residual topics, fitted on
# the trimmed DFM with a fixed seed for reproducibility.
set.seed(123)
lda <- textmodel_seededlda(
  dfmt, dict,
  residual = 2,
  batch_size = 0.1,
  auto_iter = TRUE,
  verbose = TRUE
)

# Echo the top terms per topic to the console, then save them as HTML.
print(terms(lda, 10))
lda_terms <- terms(lda)
print(knitr::kable(lda_terms))

kable_out <- lda_terms |>
  as.data.frame() |>
  kable(format = "html")
writeLines(kable_out, "output/lda_terms.html")


# Target-term model ----
# Seeded model with a single "target" topic (the Israel-related `target`
# entry of dictionary.yml) plus two residual topics. Its top 100 terms are
# exported, presumably to help choose seed words — TODO confirm.
dict_target <- dictionary(list(target = dict_yaml$target))

set.seed(123)
lda_target <- textmodel_seededlda(
  dfmt, dict_target,
  residual = 2,
  batch_size = 0.01,
  auto_iter = TRUE,
  verbose = TRUE
)

# terms() is deterministic for a fitted model, so compute it once and reuse it.
lda_terms_target <- terms(lda_target, 100)
print(lda_terms_target)
print(knitr::kable(lda_terms_target))

kable_out_target <- lda_terms_target |>
  as.data.frame() |>
  kable(format = "html")
writeLines(kable_out_target, "output/lda_terms_target.html")




# Topic distribution ----
# Per-document topic proportions (lda$theta), reshaped to long form so that
# each topic gets one box summarising its probabilities across documents.
df_probs <- as.data.frame(lda$theta)
df_probs[["doc_id"]] <- docnames(dfmt)

df_long <- melt(
  df_probs,
  id.vars = "doc_id",
  variable.name = "topic",
  value.name = "prob"
)

lda_whole <- ggplot(df_long, aes(x = topic, y = prob)) +
  geom_boxplot() +
  labs(
    title = "Seeded LDA Topic Distributions",
    x = "Topic",
    y = "Probability"
  ) +
  theme_minimal()

ggsave(
  filename = "output/lda_whole.png",
  plot = lda_whole,
  width = 8, height = 5, units = "in",
  dpi = 600, bg = "white"
)

# Topic proportions by country ----
# Readable topic labels, assigned POSITIONALLY to the columns of lda$theta.
# NOTE(review): this assumes the seeded categories in dictionary.yml are
# ordered dangerous, hostile, rational, friend, brotherhood, followed by the
# two residual topics — TODO confirm. Reordering the YAML would silently
# mislabel topics.
topic_labels <- c("dangerous", "hostile", "rational", "friend",
                  "brotherhood", "other1", "other2")

df_probs <- as.data.frame(lda$theta)
if (ncol(df_probs) != length(topic_labels)) {
  stop("Model has ", ncol(df_probs), " topics but ", length(topic_labels),
       " labels are defined; update `topic_labels`.", call. = FALSE)
}
colnames(df_probs) <- topic_labels

df_probs$doc_id <- docnames(dfmt)
# Country is taken as the document-name prefix before the first "_".
df_probs$country <- sub("_.*", "", df_probs$doc_id)

# Mean topic proportion per country.
df_country <- df_probs |>
  group_by(country) |>
  summarise(across(all_of(topic_labels), mean), .groups = "drop")

df_long_country <- melt(df_country, id.vars = "country",
                        variable.name = "topic", value.name = "prob")

lda_by_country <- ggplot(df_long_country,
                         aes(x = topic, y = prob, fill = country)) +
  geom_bar(stat = "identity", position = "dodge") +
  labs(title = "Seeded LDA Topic Probabilities by Country",
       x = "Topic", y = "Average Probability") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))
ggsave(filename = "output/lda_by_country.png",
       plot = lda_by_country,
       width = 8,
       height = 5,
       units = "in",
       bg = "white",
       dpi = 600)
